{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.data.dataloader import build_dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an example [dataset](./datasets.ipynb) to load."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from chemprop.data import MoleculeDatapoint, MoleculeDataset\n",
    "\n",
    "smis = [\"C\" * i for i in range(1, 4)]\n",
    "ys = np.random.rand(len(smis), 1)\n",
    "dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Torch dataloaders"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chemprop uses native `torch.utils.data.Dataloader`s to batch data as input to a model. `build_dataloader` is a helper function to make the dataloader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = build_dataloader(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`build_dataloader` changes the defaults of `Dataloader` to use a batch size of 64 and turn on shuffling. It also automatically uses the correct collating function for the dataset (single component vs multi-component)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from chemprop.data.collate import collate_batch, collate_multicomponent\n",
    "\n",
    "dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True, collate_fn=collate_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Collate function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The collate function takes an iterable of dataset outputs and batches them together. Iterating through batches is done automatically during training by the lightning `Trainer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TrainingBatch(bmg=<chemprop.data.collate.BatchMolGraph object at 0x107741850>, V_d=None, X_d=None, Y=tensor([[0.0562],\n",
       "        [0.5048]]), w=tensor([[1.],\n",
       "        [1.]]), lt_mask=None, gt_mask=None)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "collate_batch([dataset[0], dataset[1]])"
   ]
  },
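  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what the `Trainer` does internally, the sketch below iterates over a dataloader by hand and records the shape of each batch's targets and weights. The names `demo_smis`, `demo_ys`, `demo_dataset`, `demo_loader`, and `batch_shapes` are made up for this example; the field names `Y` and `w` are taken from the `TrainingBatch` output shown above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from chemprop.data import MoleculeDatapoint, MoleculeDataset\n",
    "from chemprop.data.dataloader import build_dataloader\n",
    "\n",
    "# Four small alkanes with random targets (illustrative data only)\n",
    "demo_smis = [\"C\" * i for i in range(1, 5)]\n",
    "demo_ys = np.random.rand(len(demo_smis), 1)\n",
    "demo_dataset = MoleculeDataset(\n",
    "    [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(demo_smis, demo_ys)]\n",
    ")\n",
    "\n",
    "# batch_size=2 splits the four datapoints into two full batches\n",
    "demo_loader = build_dataloader(demo_dataset, batch_size=2, shuffle=False)\n",
    "\n",
    "batch_shapes = []\n",
    "for batch in demo_loader:\n",
    "    # Each batch is a TrainingBatch; Y holds the targets and w the weights\n",
    "    batch_shapes.append((tuple(batch.Y.shape), tuple(batch.w.shape)))\n",
    "\n",
    "print(batch_shapes)"
   ]
  },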
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Shuffling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Shuffling the data helps improve model training, so `build_dataloader` has `shuffle=True` as the default. Shuffling should be turned off for validation and test dataloaders. Lightning gives a warning if a dataloader with shuffling is used during prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = build_dataloader(dataset)\n",
    "val_loader = build_dataloader(dataset, shuffle=False)\n",
    "test_loader = build_dataloader(dataset, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "from lightning import pytorch as pl\n",
    "from chemprop import models, nn\n",
    "\n",
    "trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_epochs=1)\n",
    "chemprop_model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.37it/s]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/brianli/Documents/chemprop/chemprop/nn/message_passing/base.py:263: UserWarning: The operator 'aten::scatter_reduce.two_out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:13.)\n",
      "  M_all = torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(\n"
     ]
    }
   ],
   "source": [
    "preds = trainer.predict(chemprop_model, dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 164.67it/s]\n"
     ]
    }
   ],
   "source": [
    "preds = trainer.predict(chemprop_model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parallel data loading"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As datapoints are sampled from the dataset, the `MolGraph` data structures are generated on-the-fly, which requires featurization of the molecular graphs. Giving the dataloader multiple workers can increase dataloading speed by preparing the datapoints in parallel. Note that this is not compatible with Windows (the process hangs) and some versions of Mac. \n",
    "\n",
    "[Caching](./dataloaders.ipynb) the the `MolGraphs` in the dataset before making the dataloader can also speed up sequential dataloading (`num_workers=0`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.utils.data.dataloader.DataLoader at 0x28bdc1510>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "build_dataloader(dataset, num_workers=8)\n",
    "\n",
    "dataset.cache = True\n",
    "build_dataloader(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Drop last batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`build_dataloader` drops the last batch if it is a single datapoint as batch normalization (the default) requires at least two data points. If you do not want to drop the last datapoint, you can adjust the batch size, or, if you aren't using batch normalization, build the dataloader manually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Dropping last batch of size 1 to avoid issues with batch normalization (dataset size = 3, batch_size = 2)\n"
     ]
    }
   ],
   "source": [
    "dataloader = build_dataloader(dataset, batch_size=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = build_dataloader(dataset, batch_size=3)\n",
    "dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_batch)"
   ]
  },
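  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rule described above can be sketched in plain Python. This is only an illustration of the described behavior, not Chemprop's actual implementation; `expected_batch_sizes` is a hypothetical helper written for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def expected_batch_sizes(n_datapoints: int, batch_size: int) -> list[int]:\n",
    "    # Hypothetical helper: batch sizes when a trailing batch of size 1 is dropped\n",
    "    n_full, remainder = divmod(n_datapoints, batch_size)\n",
    "    sizes = [batch_size] * n_full\n",
    "    # Keep a partial last batch unless it holds exactly one datapoint\n",
    "    if remainder > 1:\n",
    "        sizes.append(remainder)\n",
    "    return sizes\n",
    "\n",
    "\n",
    "print(expected_batch_sizes(3, 2))  # [2]: the single leftover datapoint is dropped\n",
    "print(expected_batch_sizes(3, 3))  # [3]: one full batch, nothing dropped\n",
    "print(expected_batch_sizes(5, 3))  # [3, 2]: a partial batch of two is kept"
   ]
  },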
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Samplers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The default sampler for a `torch.utils.data.Dataloader` is a `torch.utils.data.sampler.SequentialSampler` for `shuffle=False`, or a `torch.utils.data.sampler.RandomSampler` if `shuffle=True`. \n",
    "\n",
    "`build_dataloader` can be given a seed to make a `chemprop.data.samplers.SeededSampler` for reproducibility. Chemprop also offers `chemprop.data.samplers.ClassSampler` to equally sample positive and negative classes for binary classification tasks. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.utils.data.dataloader.DataLoader at 0x28d99a8d0>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "build_dataloader(dataset, seed=0)"
   ]
  },
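  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to check the effect of the seed is to build two dataloaders with the same seed and compare the order of the targets in their first batches. This is a sketch that assumes the seed fully determines the shuffling order; `seeded_smis`, `seeded_ys`, `seeded_dataset`, `loader_a`, and `loader_b` are names made up for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from chemprop.data import MoleculeDatapoint, MoleculeDataset\n",
    "from chemprop.data.dataloader import build_dataloader\n",
    "\n",
    "# Distinct targets (0, 1, 2, ...) make the sampling order visible in Y\n",
    "seeded_smis = [\"C\" * i for i in range(1, 9)]\n",
    "seeded_ys = np.arange(len(seeded_smis), dtype=float).reshape(-1, 1)\n",
    "seeded_dataset = MoleculeDataset(\n",
    "    [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(seeded_smis, seeded_ys)]\n",
    ")\n",
    "\n",
    "# Two independent dataloaders built with the same seed\n",
    "loader_a = build_dataloader(seeded_dataset, seed=0)\n",
    "loader_b = build_dataloader(seeded_dataset, seed=0)\n",
    "\n",
    "# All eight datapoints fit in one batch with the default batch size of 64\n",
    "Y_a = next(iter(loader_a)).Y\n",
    "Y_b = next(iter(loader_b)).Y\n",
    "\n",
    "# Expected to be True if the seed alone determines the order\n",
    "same_order = torch.equal(Y_a, Y_b)\n",
    "print(same_order)"
   ]
  },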
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.],\n",
      "        [0.],\n",
      "        [1.],\n",
      "        [0.],\n",
      "        [1.],\n",
      "        [0.],\n",
      "        [1.],\n",
      "        [0.]])\n"
     ]
    }
   ],
   "source": [
    "smis = [\"C\" * i for i in range(1, 11)]\n",
    "ys = np.random.randint(low=0, high=2, size=(len(smis), 1))\n",
    "dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])\n",
    "\n",
    "dataloader = build_dataloader(dataset, class_balance=True)\n",
    "\n",
    "_, _, _, Y, *_ = next(iter(dataloader))\n",
    "print(Y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
